import paddle
import numpy as np
from paddle.vision.transforms import Normalize

transform = Normalize(mean=[127.5], std=[127.5], data_format='CHW')
# 下载数据集并初始化 DataSet
'''
飞桨在 paddle.vision.datasets 下内置了计算机视觉（Computer Vision，CV）领域常见的数据集，
如 MNIST、Cifar10、Cifar100、FashionMNIST 和 VOC2012 等。在本任务中，
先后加载了 MNIST 训练集（mode='train'）和测试集（mode='test'），训练集用于训练模型，测试集用于评估模型效果。
'''
# 在初始化 MNIST 数据集时通过 transform 字段传入了 Normalize 变换对图像进行归一化，对图像进行归一化可以加快模型训练的收敛速度
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
test_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform)

# 打印数据集里图片数量
print('{} images in train_dataset, {} images in test_dataset'.format(len(train_dataset), len(test_dataset)))

# 模型组网并初始化网络
'''
普通的神经网络就能达到很高的精度，在本任务中使用了飞桨内置的 LeNet 作为模型
num_classes 字段中定义分类的类别数，因为需要对 0 ~ 9 的十类数字进行分类，所以设置为 10
'''
lenet = paddle.vision.models.LeNet(num_classes=10)
# 可视化模型组网结构和参数
# paddle.summary(lenet,(1, 1, 28, 28))

# 封装模型 - 将网络结构组合成可快速使用 飞桨高层 API 进行训练、评估、推理的实例，方便后续操作
model = paddle.Model(lenet)

# 模型训练的配置准备，准备损失函数，优化器和评价指标
# 这里损失函数使用常见的 CrossEntropyLoss （交叉熵损失函数），优化器使用 Adam，评价指标使用 Accuracy 来计算模型在训练集上的精度
model.prepare(paddle.optimizer.Adam(parameters=model.parameters()),
              paddle.nn.CrossEntropyLoss(),
              paddle.metric.Accuracy())

# 配置循环参数，并启动训练
# 配置参数包括指定训练的数据源 train_dataset、训练的批大小 batch_size、训练轮数 epochs 等，执行后将自动完成模型的训练循环
model.fit(train_dataset, epochs=5, batch_size=64, verbose=1)
# 模型评估 - 使用预先定义的测试数据集，来评估训练好的模型效果，评估完成后将输出模型在测试集上的损失函数值 loss 和精度 acc。
model.evaluate(test_dataset, batch_size=64, verbose=1)

# 保存模型，文件夹会自动创建
model.save('./output/mnist')
# 加载模型
model.load('output/mnist')

# 从测试集中取出一张图片
img, label = test_dataset[0]
# 将图片shape从1*28*28变为1*1*28*28，增加一个batch维度，以匹配模型输入格式要求
img_batch = np.expand_dims(img.astype('float32'), axis=0)

# 执行推理并打印结果，此处predict_batch返回的是一个list，取出其中数据获得预测结果
out = model.predict_batch(img_batch)[0]
pred_label = out.argmax()
print('true label: {}, pred label: {}'.format(label[0], pred_label))
# 可视化图片
from matplotlib import pyplot as plt

plt.imshow(img[0])
# 需要加上这句，否则不显示图片
plt.show()
